#!/usr/bin/env python3
# ==============================================================================
#
# Copyright (C) 2022 Sophgo Technologies Inc.  All rights reserved.
#
# TPU-MLIR is licensed under the 2-Clause BSD License except for the
# third-party components.
#
# ==============================================================================

import torch
from calibration.data_selector import DataSelector
from utils.preprocess import preprocess
from utils.mlir_parser import MlirParser
import os
import sys
import argparse
import numpy as np
import math
import gc
import copy

from datetime import datetime
import time
from multiprocessing.pool import ThreadPool
from multiprocessing import Lock

from staging.utils import logging
from staging.utils import learning_inputs
from staging.utils import ref_tensors
from staging.utils import LrScheduler
from staging.utils import CaliTable
from staging.utils import get_fixed_float_layers

from staging.lsq import LearningScale
from staging.gptq import LearningGptqWeight
from staging.adaround import LearningAdaWeight
from staging.easyquant import EasyQuant
from staging.lapq import LossAwareQuant
from staging.comq import Comq

import pymlir

# Module-level side effect: presumably forces pymlir to keep every
# intermediate tensor's value resident in memory so the learning passes can
# read them back as float references — TODO confirm against pymlir docs.
pymlir.set_mem_mode("force_value_mem")


def learning_adaweight_wrap(reqs):
    """Thread-pool adapter for AdaWeight learning.

    Unpacks a (searcher, epoch, op, total) request tuple and forwards it to
    the searcher's ``learning_one`` method, returning its result.
    """
    searcher, *call_args = reqs
    return searcher.learning_one(*call_args)


def learning_gptweight_wrap(reqs):
    """Thread-pool adapter for GPTQ weight learning.

    Unpacks a (searcher, epoch, op, total) request tuple and forwards it to
    the searcher's ``learning_one`` method, returning its result.
    """
    searcher = reqs[0]
    return searcher.learning_one(reqs[1], reqs[2], reqs[3])


def learning_scale_wrap(reqs):
    """Thread-pool adapter for scale learning.

    Unpacks a (searcher, op, total) request tuple and forwards it to the
    searcher's ``learning_one`` method, returning its result.
    """
    searcher, op_name, op_total = reqs
    return searcher.learning_one(op_name, op_total)


if __name__ == '__main__':
    # Driver: load an fp32 MLIR model plus an existing calibration table,
    # then refine quantization scales and/or weights with the learning
    # strategy selected by --target. Intended to run after basic calibration.
    print("TPU-MLIR {}".format(pymlir.__version__))
    # yapf: disable
    parser = argparse.ArgumentParser(
        description="Learning the scale for quantization, run after basic quant table")
    parser.add_argument('mlir_file', help='fp32 mlir file')
    parser.add_argument(
        '--dataset', required=True, type=str, help='dataset path for mix precision searching')
    parser.add_argument(
        "--data_list", required=False, type=str, help="specify a file with inputs's absolute path for mix precision searching")
    parser.add_argument(
        "--imagenet", required=False, action='store_true', dest='imagenet', help="if dataset is imagenet")
    parser.add_argument('--input_num', required=True, type=int, default=1000,
                        help='num of input samples for quantization searching')
    parser.add_argument('--data_seg', required=False, type=int, default=1000,
                        help='num of samples to buffer data on disk, they will be re-aranged after gather all samples')
    parser.add_argument('--epoch', required=False, type=int, default=1,
                        help='num of repeat times of input_num samples for weight learning')
    parser.add_argument('--mini_batch', required=False, type=int, default=4,
                        help='batch size for learning')
    parser.add_argument('--threads', required=False, type=int, default=4,
                        help='number of working threads')
    parser.add_argument('--momentum', required=False, type=float, default=0.9,
                        help='momentum of learning')
    parser.add_argument('--nesterov', required=False, action='store_true', dest='nesterov',
                        help='use nesterov in learning')
    parser.add_argument('--weight_decay', required=False, type=float, default=0.001,
                        help='weight decay in learning')
    parser.add_argument('--lr', required=False, type=float, default=0.001,
                        help='learning rate in learning')
    parser.add_argument('--lr_scheduler', required=False, type=str, default='Cosine',
                        choices=['Fixed', 'Cosine', 'MultiStep'],
                        help='lr scheduler')
    parser.add_argument('--calibration_table', required=True,
                        help='calibration table generated by calibration or tune tool')
    parser.add_argument('--weight_table', required=False, dest='weight_cali_table', default="",
                        help='weight calibration table generated by other tune tool')
    parser.add_argument('--chip', required=False, type=str, default='bm1684x',
                        choices=['bm1684x', 'bm1688', 'cv183x',
                                 'cv182x', 'cv181x', 'cv180x'],
                        help='chip platform name')
    parser.add_argument('--opt', required=False, type=str, default='SGD',
                        choices=['SGD', 'ADAM'],
                        help='Optimizer')
    parser.add_argument('--target', type=str, default='Scale',
                        choices=['Scale', 'AdaWeight', 'GptWeight', 'EasyQuant', 'LossAwareQuant', 'Comq'],
                        help='to learn scale or weight or both')
    parser.add_argument('-o', '--output_calibration_table', required=False, default="./new_cali",
                        help='output of calibration table after learning')
    parser.add_argument('-qtable', '--qtable', required=False, default="",
                        help='qtable in which not quant layers marked')
    parser.add_argument('-excepts', '--excepts', required=False, default="",
                        help='learning excepts these layers, split with comma')
    parser.add_argument('-quant_layers', '--quant_layers', required=False, default="",
                        help='only handle these layers, list of op names')
    parser.add_argument(
        "--quantweight", required=False, action='store_true', dest='quantweight', help="if update weight in npz or export weight-th-table")
    # yapf: enable
    args = parser.parse_args()
    # Despite the wider --chip choices list, only the bm1684x/bm1688 backends
    # are actually supported by the learning flows below.
    if args.chip != "bm1684x" and args.chip != "bm1688":
        print("only support bm1684x and bm1688 till now!")
        sys.exit(1)
    # The on-disk buffering segment can never exceed the total sample count.
    if args.data_seg > args.input_num:
        args.data_seg = args.input_num
    loger = logging()  # project logger; 'loger' spelling is the attribute name the searcher classes expect
    # NOTE(review): 'pool' is created but never used in this file; the *_wrap
    # helpers above suggest pool.map usage lives (or lived) elsewhere — confirm.
    pool = ThreadPool(args.threads)
    scale_searcher = LearningScale(args)
    scale_searcher.loger = loger
    # CaliTable parses the input table; .table / .table4 hold the int8 / int4
    # entries respectively (presumably name -> threshold — TODO confirm).
    cali_table = CaliTable(args.calibration_table, args.output_calibration_table)
    scale_searcher.orig_scales8 = cali_table.table
    scale_searcher.orig_scales4 = cali_table.table4
    # Prepare the learning dataset; num_sample is the count actually prepared
    # (presumably may be fewer than args.input_num — verify in learning_inputs).
    all_inputs = learning_inputs(scale_searcher.parser, args)
    num_sample = all_inputs.prepare(args.input_num)

    # Exactly one of these flags is True: --target choices are exclusive.
    learn_scale = args.target == "Scale"
    learn_adaweight = args.target == "AdaWeight"
    learn_gptweight = args.target == "GptWeight"
    learn_dualthreshold = args.target == "EasyQuant"
    learn_lapq = args.target == "LossAwareQuant"
    learn_comq = args.target == "Comq"

    print(
        f'Learning Scale: {learn_scale}; Learning AdaWeight: {learn_adaweight}; Learning GptWeight: {learn_gptweight}; Learning Active and Weight Threahold: {learn_dualthreshold}'
    )
    # --target Scale: learn activation scales (staging.lsq.LearningScale),
    # then write the updated thresholds to the output calibration table.
    if learn_scale:
        scale_searcher.num_sample = num_sample
        scale_searcher.ref_tensors = ref_tensors(scale_searcher, all_inputs, loger)
        scheduler = LrScheduler(args.lr, scale_searcher.num_sample, args.lr_scheduler)
        if args.opt == 'SGD':
            scale_searcher.init_sgd(scheduler, args.momentum, args.nesterov, args.weight_decay)
        else:
            # ADAM path uses fixed beta1=0.9 / beta2=0.999.
            scale_searcher.init_adam(scheduler, 0.9, 0.999, args.weight_decay)
        scale_searcher.learning()
        cali_table.update(scale_searcher.new_scales)
        cali_table.write()
        # Free the cached reference tensors before exiting (they can be large).
        del scale_searcher.ref_tensors
        del scale_searcher
    # --target AdaWeight: adaptive-rounding weight learning
    # (staging.adaround.LearningAdaWeight); LR schedule spans num_sample * epoch steps.
    if learn_adaweight:
        scheduler = LrScheduler(args.lr, num_sample * args.epoch, args.lr_scheduler)
        weight_searcher = LearningAdaWeight(args)
        weight_searcher.loger = loger
        weight_searcher.scales = cali_table.table
        weight_searcher.scales4 = cali_table.table4
        weight_searcher.num_sample = num_sample
        weight_searcher.ref_tensors = ref_tensors(weight_searcher, all_inputs, loger)
        weight_searcher.init_opt(scheduler, args.momentum, args.nesterov, args.weight_decay)
        weight_searcher.learning()
        del weight_searcher.ref_tensors
        del weight_searcher
    # --target GptWeight: GPTQ-style weight learning; layers forced to float
    # by the qtable are filtered out first.
    if learn_gptweight:
        weight_searcher = LearningGptqWeight(args)
        weight_searcher.loger = loger
        weight_searcher.scales = cali_table.table
        weight_searcher.scales4 = cali_table.table4
        weight_searcher.num_sample = num_sample
        weight_searcher.ref_tensors = ref_tensors(weight_searcher, all_inputs, loger)
        fix_float = get_fixed_float_layers(args.mlir_file, 'INT8', args.chip,
                                           args.calibration_table, args.qtable,
                                           weight_searcher.ref_tensors)
        weight_searcher.filter_fixed_floats(fix_float)
        weight_searcher.learning()
        # NOTE(review): unlike the Scale/EasyQuant paths, no cali_table.write()
        # here — presumably GPTQ updates weights rather than thresholds; confirm.
        del weight_searcher
    # --target EasyQuant: jointly tune activation and weight thresholds, then
    # write the updated (int4) entries back to the output table.
    if learn_dualthreshold:
        equant_searcher = EasyQuant(args)
        equant_searcher.loger = loger
        equant_searcher.scales = cali_table.table
        equant_searcher.scales4 = cali_table.table4
        equant_searcher.num_sample = num_sample
        equant_searcher.ref_tensors = ref_tensors(equant_searcher, all_inputs, loger)
        equant_searcher.learning()
        cali_table.update(equant_searcher.scales4, True)
        cali_table.write()
    # --target LossAwareQuant: receives cali_table directly, so any table
    # updates/writes happen inside the searcher — verify in staging.lapq.
    if learn_lapq:
        lapq_searcher = LossAwareQuant(cali_table, args)
        lapq_searcher.loger = loger
        lapq_searcher.num_sample = num_sample
        lapq_searcher.ref_tensors = ref_tensors(lapq_searcher, all_inputs, loger)
        lapq_searcher.learning()
    # --target Comq: same pattern as LossAwareQuant (table handled internally).
    if learn_comq:
        comq_searcher = Comq(cali_table, args)
        comq_searcher.loger = loger
        comq_searcher.num_sample = num_sample
        comq_searcher.ref_tensors = ref_tensors(comq_searcher, all_inputs, loger)
        comq_searcher.learning()

    loger.end()
